import torch
import os
from flask import Flask, Response, jsonify, request, Blueprint
from flask_restful import Api, Resource
import pickle
import argparse
import threading
import argparse


# Pin this process to the last visible CUDA device and use it for every model
# loaded below. NOTE(review): presumably chosen to leave lower-index GPUs free
# for other jobs — confirm. Requires at least one CUDA device: with zero
# devices this builds the invalid string 'cuda:-1' and set_device fails.
device = f'cuda:{torch.cuda.device_count()-1}'
torch.cuda.set_device(device)
# Default dtype for the heavy models (the VAE and the LLM text encoder are
# moved to this dtype below; the CLIP encoder keeps its native dtype).
dtype = torch.bfloat16

def parsed_args(argv=None):
    """Parse command-line options for the StepVideo API server.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse falls back to sys.argv[1:], preserving the original
            call pattern `parsed_args()`.

    Returns:
        argparse.Namespace with model_dir, clip_dir, llm_dir, vae_dir, port.
    """
    parser = argparse.ArgumentParser(description="StepVideo API Functions")
    # model_dir is joined with the sub-directories below via os.path.join,
    # so omitting it would crash later with a TypeError; require it up front.
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--clip_dir', type=str, default='hunyuan_clip')
    parser.add_argument('--llm_dir', type=str, default='step_llm')
    parser.add_argument('--vae_dir', type=str, default='vae')
    # Parse the port as an int; Flask's app.run accepts it directly.
    parser.add_argument('--port', type=int, default=8080)
    args = parser.parse_args(argv)
    return args



class StepVaePipeline(Resource):
    """Wraps the StepVideo AutoencoderKL for latent encode/decode.

    The model is loaded once at construction onto the module-level
    `device` in the module-level `dtype` (bfloat16).
    """

    def __init__(self, vae_dir, version=2):
        self.vae = self.build_vae(vae_dir, version)
        # Latents are multiplied by this factor on encode and divided on
        # decode; 1.0 means latents pass through unscaled.
        self.scale_factor = 1.0

    def build_vae(self, vae_dir, version=2):
        """Load the AutoencoderKL checkpoint for the requested version.

        v2 uses the 'vae_v2.safetensors' weights with 64 latent channels;
        any other version falls back to 'vae.safetensors' with 16 channels.
        """
        from stepvideo.vae.vae import AutoencoderKL
        model_name, z_channels = ("vae_v2.safetensors", 64) if version == 2 else ("vae.safetensors", 16)
        model_path = os.path.join(vae_dir, model_name)

        model = AutoencoderKL(
            z_channels=z_channels,
            model_path=model_path,
            version=version,
        ).to(dtype).to(device).eval()
        print("Initialized vae...")
        return model

    def _param_dtype_device(self):
        """Return (dtype, device) of the loaded VAE's parameters."""
        param = next(self.vae.parameters())
        return param.dtype, param.device

    def decode(self, samples, *args, **kwargs):
        """Decode latents to pixel space; returns None on failure."""
        with torch.no_grad():
            try:
                vae_dtype, vae_device = self._param_dtype_device()
                samples = self.vae.decode(samples.to(vae_dtype).to(vae_device) / self.scale_factor)
                # AutoencoderKL may return a wrapper object exposing .sample.
                if hasattr(samples, 'sample'):
                    samples = samples.sample
                return samples
            except Exception as err:
                # Best-effort service: log, free GPU memory, and report
                # failure as None so the server keeps running.
                print(f"vae decode error: {err}")
                torch.cuda.empty_cache()
                return None

    def encode(self, videos, *args, **kwargs):
        """Encode videos to latent space; returns None on failure."""
        with torch.no_grad():
            try:
                vae_dtype, vae_device = self._param_dtype_device()
                latents = self.vae.encode(videos.to(vae_dtype).to(vae_device)) * self.scale_factor
                # Mirror decode(): unwrap objects exposing .sample.
                if hasattr(latents, 'sample'):
                    latents = latents.sample
                return latents
            except Exception as err:
                print(f"vae encode error: {err}")
                torch.cuda.empty_cache()
                return None

# Single mutex serializing all GPU work across endpoints: the pipelines are
# not safe for concurrent requests.
lock = threading.Lock()


class VAEapi(Resource):
    """HTTP endpoint that decodes pickled latents into frames via the VAE."""

    def __init__(self, vae_pipeline):
        self.vae_pipeline = vae_pipeline

    def get(self):
        with lock:
            try:
                # SECURITY: pickle.loads on request data can execute arbitrary
                # code — only expose this service on a trusted network.
                feature = pickle.loads(request.get_data())
                feature['api'] = 'vae'

                # Drop None-valued entries so they don't override defaults.
                feature = {k: v for k, v in feature.items() if v is not None}
                video_latents = self.vae_pipeline.decode(**feature)
                response = pickle.dumps(video_latents)

            except Exception as e:
                print("Caught Exception: ", e)
                # Response() cannot take an exception object (it would raise a
                # TypeError); send the message text with an error status.
                return Response(str(e), status=500)

            return Response(response)


class VAEEncodeapi(Resource):
    """HTTP endpoint that encodes pickled videos into VAE latents."""

    def __init__(self, vae_pipeline):
        self.vae_pipeline = vae_pipeline

    def get(self):
        with lock:
            try:
                # SECURITY: pickle.loads on request data can execute arbitrary
                # code — only expose this service on a trusted network.
                feature = pickle.loads(request.get_data())
                feature['api'] = 'vae-encode'

                # Drop None-valued entries so they don't override defaults.
                feature = {k: v for k, v in feature.items() if v is not None}
                video_latents = self.vae_pipeline.encode(**feature)
                response = pickle.dumps(video_latents)

            except Exception as e:
                print("Caught Exception: ", e)
                # Response() cannot take an exception object (it would raise a
                # TypeError); send the message text with an error status.
                return Response(str(e), status=500)

            return Response(response)


class CaptionPipeline(Resource):
    """Computes prompt embeddings with the Step LLM and Hunyuan CLIP encoders."""

    def __init__(self, llm_dir, clip_dir):
        self.text_encoder = self.build_llm(llm_dir)
        self.clip = self.build_clip(clip_dir)

    def build_llm(self, model_dir):
        """Load the STEP1 text encoder (max 320 tokens) in bfloat16, eval mode."""
        from stepvideo.text_encoder.stepllm import STEP1TextEncoder
        text_encoder = STEP1TextEncoder(model_dir, max_length=320).to(dtype).to(device).eval()
        print("Initialized text encoder...")
        return text_encoder

    def build_clip(self, model_dir):
        """Load the Hunyuan CLIP text encoder (max 77 tokens) in eval mode."""
        from stepvideo.text_encoder.clip import HunyuanClip
        clip = HunyuanClip(model_dir, max_length=77).to(device).eval()
        print("Initialized clip encoder...")
        return clip

    def embedding(self, prompts, *args, **kwargs):
        """Embed `prompts` with both encoders.

        Returns a dict of CPU tensors — 'y' (LLM embeddings), 'y_mask'
        (attention mask left-padded with 1s by CLIP's sequence length) and
        'clip_embedding' (bfloat16) — or None if embedding fails.
        """
        with torch.no_grad():
            try:
                y, y_mask = self.text_encoder(prompts)
                clip_embedding, _ = self.clip(prompts)

                # Left-pad the attention mask with 1s for CLIP's tokens,
                # presumably because the caller concatenates clip_embedding
                # ahead of y — TODO confirm against the consumer.
                len_clip = clip_embedding.shape[1]
                y_mask = torch.nn.functional.pad(y_mask, (len_clip, 0), value=1)

                data = {
                    'y': y.detach().cpu(),
                    'y_mask': y_mask.detach().cpu(),
                    'clip_embedding': clip_embedding.to(torch.bfloat16).detach().cpu()
                }

                return data
            except Exception as err:
                # Best-effort: log with context and report failure as None.
                print(f"caption embedding error: {err}")
                return None



# NOTE: the redundant second `lock = threading.Lock()` was removed here — the
# single module-level lock defined above is resolved at call time by every
# handler, so all endpoints share one GPU mutex.
class Captionapi(Resource):
    """HTTP endpoint that turns pickled prompts into text-encoder embeddings."""

    def __init__(self, caption_pipeline):
        self.caption_pipeline = caption_pipeline

    def get(self):
        with lock:
            try:
                # SECURITY: pickle.loads on request data can execute arbitrary
                # code — only expose this service on a trusted network.
                feature = pickle.loads(request.get_data())
                feature['api'] = 'caption'

                # Drop None-valued entries so they don't override defaults.
                feature = {k: v for k, v in feature.items() if v is not None}
                embeddings = self.caption_pipeline.embedding(**feature)
                response = pickle.dumps(embeddings)

            except Exception as e:
                print("Caught Exception: ", e)
                # Response() cannot take an exception object (it would raise a
                # TypeError); send the message text with an error status.
                return Response(str(e), status=500)

            return Response(response)




class RemoteServer(object):
    """Flask application exposing the VAE and caption pipelines over HTTP.

    Endpoints:
        /vae-api        - decode latents (VAEapi)
        /vae-encode-api - encode videos (VAEEncodeapi)
        /caption-api    - prompt embeddings (Captionapi)
    """

    def __init__(self, args) -> None:
        self.app = Flask(__name__)
        blueprint = Blueprint("root", __name__)
        self.app.register_blueprint(blueprint)
        rest_api = Api(self.app)

        # One shared VAE pipeline backs both the decode and encode routes.
        self.vae_pipeline = StepVaePipeline(
            vae_dir=os.path.join(args.model_dir, args.vae_dir)
        )
        rest_api.add_resource(
            VAEapi,
            "/vae-api",
            resource_class_args=[self.vae_pipeline],
        )
        rest_api.add_resource(
            VAEEncodeapi,
            "/vae-encode-api",
            resource_class_args=[self.vae_pipeline],
        )

        # Text-encoder pipeline (Step LLM + Hunyuan CLIP) for prompt embeddings.
        self.caption_pipeline = CaptionPipeline(
            llm_dir=os.path.join(args.model_dir, args.llm_dir),
            clip_dir=os.path.join(args.model_dir, args.clip_dir)
        )
        rest_api.add_resource(
            Captionapi,
            "/caption-api",
            resource_class_args=[self.caption_pipeline],
        )

    def run(self, host="0.0.0.0", port=8080):
        """Start the blocking Flask server (threaded, no debug)."""
        self.app.run(host, port=port, threaded=True, debug=False)


if __name__ == "__main__":
    args = parsed_args()
    flask_server = RemoteServer(args)
    flask_server.run(host="0.0.0.0", port=args.port)
    